{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Matrix Factorization\n",
    "\n",
    "In a recommendation system, there is a group of users and a set of items. Given that each users have rated some items in the system, we would like to predict how the users would rate the items that they have not yet rated, such that we can make recommendations to the users.\n",
    "\n",
    "Matrix factorization is one of the mainly used algorithm in recommendation systems. It can be used to discover latent features underlying the interactions between two different kinds of entities.\n",
    "\n",
    "Assume we assign a k-dimensional vector to each user and a k-dimensional vector to each item such that the dot product of these two vectors gives the user's rating of that item. We can learn the user and item vectors directly, which is essentially performing SVD on the user-item matrix. We can also try to learn the latent features using multi-layer neural networks. \n",
    "\n",
    "In this tutorial, we will work though the steps to implement these ideas in MXNet."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Data\n",
    "\n",
    "We use the [MovieLens](http://grouplens.org/datasets/movielens/) data here, but it can apply to other datasets as well. Each row of this dataset contains a tuple of user id, movie id, rating, and time stamp, we will only use the first three items. We first define the a batch which contains n tuples. It also provides name and shape information to MXNet about the data and label. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Batch(object):\n",
    "    def __init__(self, data_names, data, label_names, label):\n",
    "        self.data = data\n",
    "        self.label = label\n",
    "        self.data_names = data_names\n",
    "        self.label_names = label_names\n",
    "        \n",
    "    @property\n",
    "    def provide_data(self):\n",
    "        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]\n",
    "    \n",
    "    @property\n",
    "    def provide_label(self):\n",
    "        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we define a data iterator, which returns a batch of tuples each time. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "import random\n",
    "\n",
    "class Batch(object):\n",
    "    def __init__(self, data_names, data, label_names, label):\n",
    "        self.data = data\n",
    "        self.label = label\n",
    "        self.data_names = data_names\n",
    "        self.label_names = label_names\n",
    "\n",
    "    @property\n",
    "    def provide_data(self):\n",
    "        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]\n",
    "\n",
    "    @property\n",
    "    def provide_label(self):\n",
    "        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]\n",
    "\n",
    "class DataIter(mx.io.DataIter):\n",
    "    def __init__(self, fname, batch_size):\n",
    "        super(DataIter, self).__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.data = []\n",
    "        for line in file(fname):\n",
    "            tks = line.strip().split('\\t')\n",
    "            if len(tks) != 4:\n",
    "                continue\n",
    "            self.data.append((int(tks[0]), int(tks[1]), float(tks[2])))\n",
    "        self.provide_data = [('user', (batch_size, )), ('item', (batch_size, ))]\n",
    "        self.provide_label = [('score', (self.batch_size, ))]\n",
    "\n",
    "    def __iter__(self):\n",
    "        for k in range(len(self.data) / self.batch_size):\n",
    "            users = []\n",
    "            items = []\n",
    "            scores = []\n",
    "            for i in range(self.batch_size):\n",
    "                j = k * self.batch_size + i\n",
    "                user, item, score = self.data[j]\n",
    "                users.append(user)\n",
    "                items.append(item)\n",
    "                scores.append(score)\n",
    "\n",
    "            data_all = [mx.nd.array(users), mx.nd.array(items)]\n",
    "            label_all = [mx.nd.array(scores)]\n",
    "            data_names = ['user', 'item']\n",
    "            label_names = ['score']\n",
    "\n",
    "            data_batch = Batch(data_names, data_all, label_names, label_all)\n",
    "            yield data_batch\n",
    "\n",
    "    def reset(self):\n",
    "        random.shuffle(self.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we download the data and provide a function to obtain the data iterator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib\n",
    "import zipfile\n",
    "if not os.path.exists('ml-100k.zip'):\n",
    "    urllib.urlretrieve('http://files.grouplens.org/datasets/movielens/ml-100k.zip', 'ml-100k.zip')\n",
    "with zipfile.ZipFile(\"ml-100k.zip\",\"r\") as f:\n",
    "    f.extractall(\"./\")\n",
    "def get_data(batch_size):\n",
    "    return (DataIter('./ml-100k/u1.base', batch_size), DataIter('./ml-100k/u1.test', batch_size))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally we calculate the numbers of users and items for later use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(944, 1683)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def max_id(fname):\n",
    "    mu = 0\n",
    "    mi = 0\n",
    "    for line in file(fname):\n",
    "        tks = line.strip().split('\\t')\n",
    "        if len(tks) != 4:\n",
    "            continue\n",
    "        mu = max(mu, int(tks[0]))\n",
    "        mi = max(mi, int(tks[1]))\n",
    "    return mu + 1, mi + 1\n",
    "max_user, max_item = max_id('./ml-100k/u.data')\n",
    "(max_user, max_item)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimization\n",
    "\n",
    "We first implement the RMSE (root-mean-square error) measurement, which is commonly used by matrix factorization. "
   ]
  },
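  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, with $n$ rated pairs in a batch, true ratings $r_j$ and predictions $\\hat{r}_j$ (symbols introduced here for illustration), the quantity the metric reports is:\n",
    "\n",
    "$$\\mathrm{RMSE} = \\sqrt{\\frac{1}{n} \\sum_{j=1}^{n} \\left(r_j - \\hat{r}_j\\right)^2}$$\n",
    "\n",
    "For example, predictions of 3.5 and 4.0 against true ratings of 4 and 3 give $\\sqrt{(0.25 + 1.0) / 2} \\approx 0.79$."
   ]
  },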
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import math\n",
    "def RMSE(label, pred):\n",
    "    ret = 0.0\n",
    "    n = 0.0\n",
    "    pred = pred.flatten()\n",
    "    for i in range(len(label)):\n",
    "        ret += (label[i] - pred[i]) * (label[i] - pred[i])\n",
    "        n += 1.0\n",
    "    return math.sqrt(ret / n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we define a general training module, which is borrowed from the image classification application. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train(network, batch_size, num_epoch, learning_rate):\n",
    "    batch_size = batch_size\n",
    "    train, test = get_data(batch_size)\n",
    "    model = mx.mod.Module(symbol = network, \n",
    "                          data_names=[x[0] for x in train.provide_data],\n",
    "                          label_names=[y[0] for y in train.provide_label],\n",
    "                          context=[mx.gpu(0)])\n",
    "\n",
    "    import logging\n",
    "    head = '%(asctime)-15s %(message)s'\n",
    "    logging.basicConfig(level=logging.DEBUG)\n",
    "\n",
    "    model.fit(train_data = train, \n",
    "              eval_data = test,\n",
    "              num_epoch=num_epoch,\n",
    "              optimizer='sgd',\n",
    "              optimizer_params={'learning_rate':learning_rate, 'momentum':0.9, 'wd':0.0001},\n",
    "              eval_metric = RMSE,\n",
    "              batch_end_callback=mx.callback.Speedometer(batch_size, 20000/batch_size),)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Networks\n",
    "\n",
    "Now we try various networks. We first learn the latent vectors directly."
   ]
  },
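  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of the prediction rule (the numbers are made up for illustration), take $k = 3$, a user vector $p_u = (1, 2, 3)$ and an item vector $q_i = (0.5, 0, 1)$. The elementwise product followed by a sum over the feature axis gives the inner product:\n",
    "\n",
    "$$\\hat{r}_{ui} = \\sum_{f=1}^{3} p_{uf} q_{if} = 1 \\cdot 0.5 + 2 \\cdot 0 + 3 \\cdot 1 = 3.5$$\n",
    "\n",
    "In a batch, the same computation is applied to each row of the `(batch_size, k)` embedding outputs, which is why the sum is taken along axis 1."
   ]
  },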
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Start training with [gpu(0)]\n",
      "INFO:root:Epoch[0] Batch [312]\tSpeed: 41428.85 samples/sec\tTrain-RMSE=3.699369\n",
      "INFO:root:Epoch[0] Batch [624]\tSpeed: 42144.73 samples/sec\tTrain-RMSE=3.702055\n",
      "INFO:root:Epoch[0] Batch [936]\tSpeed: 42175.25 samples/sec\tTrain-RMSE=3.694050\n",
      "INFO:root:Epoch[0] Batch [1248]\tSpeed: 42307.85 samples/sec\tTrain-RMSE=3.698962\n",
      "INFO:root:Epoch[0] Resetting Data Iterator\n",
      "INFO:root:Epoch[0] Time cost=2.174\n",
      "INFO:root:Epoch[0] Validation-RMSE=3.714925\n",
      "INFO:root:Epoch[1] Batch [312]\tSpeed: 42458.83 samples/sec\tTrain-RMSE=3.680307\n",
      "INFO:root:Epoch[1] Batch [624]\tSpeed: 42960.69 samples/sec\tTrain-RMSE=3.627375\n",
      "INFO:root:Epoch[1] Batch [936]\tSpeed: 42684.44 samples/sec\tTrain-RMSE=3.283429\n",
      "INFO:root:Epoch[1] Batch [1248]\tSpeed: 42694.95 samples/sec\tTrain-RMSE=2.586398\n",
      "INFO:root:Epoch[1] Resetting Data Iterator\n",
      "INFO:root:Epoch[1] Time cost=1.905\n",
      "INFO:root:Epoch[1] Validation-RMSE=2.448542\n",
      "INFO:root:Epoch[2] Batch [312]\tSpeed: 43118.68 samples/sec\tTrain-RMSE=2.025510\n",
      "INFO:root:Epoch[2] Batch [624]\tSpeed: 42748.51 samples/sec\tTrain-RMSE=1.712786\n",
      "INFO:root:Epoch[2] Batch [936]\tSpeed: 42972.15 samples/sec\tTrain-RMSE=1.505154\n",
      "INFO:root:Epoch[2] Batch [1248]\tSpeed: 42779.65 samples/sec\tTrain-RMSE=1.376185\n",
      "INFO:root:Epoch[2] Resetting Data Iterator\n",
      "INFO:root:Epoch[2] Time cost=1.897\n",
      "INFO:root:Epoch[2] Validation-RMSE=1.432550\n",
      "INFO:root:Epoch[3] Batch [312]\tSpeed: 43192.26 samples/sec\tTrain-RMSE=1.266802\n",
      "INFO:root:Epoch[3] Batch [624]\tSpeed: 43051.19 samples/sec\tTrain-RMSE=1.218756\n",
      "INFO:root:Epoch[3] Batch [936]\tSpeed: 42522.01 samples/sec\tTrain-RMSE=1.169181\n",
      "INFO:root:Epoch[3] Batch [1248]\tSpeed: 42781.29 samples/sec\tTrain-RMSE=1.142398\n",
      "INFO:root:Epoch[3] Resetting Data Iterator\n",
      "INFO:root:Epoch[3] Time cost=1.898\n",
      "INFO:root:Epoch[3] Validation-RMSE=1.202146\n",
      "INFO:root:Epoch[4] Batch [312]\tSpeed: 42951.35 samples/sec\tTrain-RMSE=1.094332\n",
      "INFO:root:Epoch[4] Batch [624]\tSpeed: 42675.15 samples/sec\tTrain-RMSE=1.085629\n",
      "INFO:root:Epoch[4] Batch [936]\tSpeed: 42919.15 samples/sec\tTrain-RMSE=1.073905\n",
      "INFO:root:Epoch[4] Batch [1248]\tSpeed: 42915.65 samples/sec\tTrain-RMSE=1.056360\n",
      "INFO:root:Epoch[4] Resetting Data Iterator\n",
      "INFO:root:Epoch[4] Time cost=1.898\n",
      "INFO:root:Epoch[4] Validation-RMSE=1.116253\n",
      "INFO:root:Epoch[5] Batch [312]\tSpeed: 43046.65 samples/sec\tTrain-RMSE=1.031064\n",
      "INFO:root:Epoch[5] Batch [624]\tSpeed: 43046.10 samples/sec\tTrain-RMSE=1.031983\n",
      "INFO:root:Epoch[5] Batch [936]\tSpeed: 42901.89 samples/sec\tTrain-RMSE=1.024316\n",
      "INFO:root:Epoch[5] Batch [1248]\tSpeed: 43107.51 samples/sec\tTrain-RMSE=1.028240\n",
      "INFO:root:Epoch[5] Resetting Data Iterator\n",
      "INFO:root:Epoch[5] Time cost=1.892\n",
      "INFO:root:Epoch[5] Validation-RMSE=1.073396\n",
      "INFO:root:Epoch[6] Batch [312]\tSpeed: 43002.51 samples/sec\tTrain-RMSE=1.010509\n",
      "INFO:root:Epoch[6] Batch [624]\tSpeed: 42726.47 samples/sec\tTrain-RMSE=1.003917\n",
      "INFO:root:Epoch[6] Batch [936]\tSpeed: 42738.72 samples/sec\tTrain-RMSE=1.002835\n",
      "INFO:root:Epoch[6] Batch [1248]\tSpeed: 42982.32 samples/sec\tTrain-RMSE=1.000915\n",
      "INFO:root:Epoch[6] Resetting Data Iterator\n",
      "INFO:root:Epoch[6] Time cost=1.898\n",
      "INFO:root:Epoch[6] Validation-RMSE=1.049511\n",
      "INFO:root:Epoch[7] Batch [312]\tSpeed: 43566.30 samples/sec\tTrain-RMSE=0.985503\n",
      "INFO:root:Epoch[7] Batch [624]\tSpeed: 43287.67 samples/sec\tTrain-RMSE=0.995254\n",
      "INFO:root:Epoch[7] Batch [936]\tSpeed: 43351.22 samples/sec\tTrain-RMSE=0.993624\n",
      "INFO:root:Epoch[7] Batch [1248]\tSpeed: 43160.14 samples/sec\tTrain-RMSE=0.988817\n",
      "INFO:root:Epoch[7] Resetting Data Iterator\n",
      "INFO:root:Epoch[7] Time cost=1.877\n",
      "INFO:root:Epoch[7] Validation-RMSE=1.035829\n",
      "INFO:root:Epoch[8] Batch [312]\tSpeed: 43220.97 samples/sec\tTrain-RMSE=0.984125\n",
      "INFO:root:Epoch[8] Batch [624]\tSpeed: 43146.06 samples/sec\tTrain-RMSE=0.982134\n",
      "INFO:root:Epoch[8] Batch [936]\tSpeed: 43274.74 samples/sec\tTrain-RMSE=0.978188\n",
      "INFO:root:Epoch[8] Batch [1248]\tSpeed: 43234.06 samples/sec\tTrain-RMSE=0.976167\n",
      "INFO:root:Epoch[8] Resetting Data Iterator\n",
      "INFO:root:Epoch[8] Time cost=1.885\n",
      "INFO:root:Epoch[8] Validation-RMSE=1.025111\n",
      "INFO:root:Epoch[9] Batch [312]\tSpeed: 43331.84 samples/sec\tTrain-RMSE=0.965266\n",
      "INFO:root:Epoch[9] Batch [624]\tSpeed: 43185.53 samples/sec\tTrain-RMSE=0.984249\n",
      "INFO:root:Epoch[9] Batch [936]\tSpeed: 43261.89 samples/sec\tTrain-RMSE=0.972151\n",
      "INFO:root:Epoch[9] Batch [1248]\tSpeed: 43151.76 samples/sec\tTrain-RMSE=0.972135\n",
      "INFO:root:Epoch[9] Resetting Data Iterator\n",
      "INFO:root:Epoch[9] Time cost=1.883\n",
      "INFO:root:Epoch[9] Validation-RMSE=1.016984\n"
     ]
    }
   ],
   "source": [
    "# @@@ AUTOTEST_OUTPUT_IGNORED_CELL\n",
    "def plain_net(k):\n",
    "    # input\n",
    "    user = mx.symbol.Variable('user')\n",
    "    item = mx.symbol.Variable('item')\n",
    "    score = mx.symbol.Variable('score')\n",
    "    # user feature lookup\n",
    "    user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k) \n",
    "    # item feature lookup\n",
    "    item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)\n",
    "    # predict by the inner product, which is elementwise product and then sum\n",
    "    pred = user * item\n",
    "    pred = mx.symbol.sum_axis(data = pred, axis = 1)\n",
    "    pred = mx.symbol.Flatten(data = pred)\n",
    "    # loss layer\n",
    "    pred = mx.symbol.LinearRegressionOutput(data = pred, label = score)\n",
    "    return pred\n",
    "\n",
    "train(plain_net(64), batch_size=64, num_epoch=10, learning_rate=.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we try to use 2 layers neural network to learn the latent variables, which stack a fully connected layer above the embedding layers: "
   ]
  },
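  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols (our notation, not the original tutorial's), let $p_u$ and $q_i$ be the embedding vectors of user $u$ and item $i$. Each side applies a ReLU and then its own fully connected layer, with weights $W$ and bias $b$, before the inner product is taken:\n",
    "\n",
    "$$h_u = W_{\\mathrm{user}} \\max(p_u, 0) + b_{\\mathrm{user}}, \\qquad h_i = W_{\\mathrm{item}} \\max(q_i, 0) + b_{\\mathrm{item}}, \\qquad \\hat{r}_{ui} = h_u^\\top h_i$$\n",
    "\n",
    "Here $W_{\\mathrm{user}}$ and $W_{\\mathrm{item}}$ have shape `(hidden, k)`, so the inner product is taken in a `hidden`-dimensional space."
   ]
  },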
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Start training with [gpu(0)]\n",
      "INFO:root:Epoch[0] Batch [312]\tSpeed: 29358.84 samples/sec\tTrain-RMSE=1.338036\n",
      "INFO:root:Epoch[0] Batch [624]\tSpeed: 29040.21 samples/sec\tTrain-RMSE=1.030033\n",
      "INFO:root:Epoch[0] Batch [936]\tSpeed: 30795.04 samples/sec\tTrain-RMSE=1.005132\n",
      "INFO:root:Epoch[0] Batch [1248]\tSpeed: 29430.40 samples/sec\tTrain-RMSE=0.987999\n",
      "INFO:root:Epoch[0] Resetting Data Iterator\n",
      "INFO:root:Epoch[0] Time cost=2.733\n",
      "INFO:root:Epoch[0] Validation-RMSE=0.990343\n",
      "INFO:root:Epoch[1] Batch [312]\tSpeed: 29327.44 samples/sec\tTrain-RMSE=0.970059\n",
      "INFO:root:Epoch[1] Batch [624]\tSpeed: 29056.96 samples/sec\tTrain-RMSE=0.963783\n",
      "INFO:root:Epoch[1] Batch [936]\tSpeed: 28939.17 samples/sec\tTrain-RMSE=0.966315\n",
      "INFO:root:Epoch[1] Batch [1248]\tSpeed: 29150.33 samples/sec\tTrain-RMSE=0.970925\n",
      "INFO:root:Epoch[1] Resetting Data Iterator\n",
      "INFO:root:Epoch[1] Time cost=2.781\n",
      "INFO:root:Epoch[1] Validation-RMSE=0.978711\n",
      "INFO:root:Epoch[2] Batch [312]\tSpeed: 29850.39 samples/sec\tTrain-RMSE=0.949423\n",
      "INFO:root:Epoch[2] Batch [624]\tSpeed: 29131.78 samples/sec\tTrain-RMSE=0.950519\n",
      "INFO:root:Epoch[2] Batch [936]\tSpeed: 29222.54 samples/sec\tTrain-RMSE=0.953821\n",
      "INFO:root:Epoch[2] Batch [1248]\tSpeed: 29144.28 samples/sec\tTrain-RMSE=0.957096\n",
      "INFO:root:Epoch[2] Resetting Data Iterator\n",
      "INFO:root:Epoch[2] Time cost=2.762\n",
      "INFO:root:Epoch[2] Validation-RMSE=0.967527\n",
      "INFO:root:Epoch[3] Batch [312]\tSpeed: 31087.31 samples/sec\tTrain-RMSE=0.938336\n",
      "INFO:root:Epoch[3] Batch [624]\tSpeed: 30903.51 samples/sec\tTrain-RMSE=0.940157\n",
      "INFO:root:Epoch[3] Batch [936]\tSpeed: 30788.64 samples/sec\tTrain-RMSE=0.961417\n",
      "INFO:root:Epoch[3] Batch [1248]\tSpeed: 29309.31 samples/sec\tTrain-RMSE=0.957611\n",
      "INFO:root:Epoch[3] Resetting Data Iterator\n",
      "INFO:root:Epoch[3] Time cost=2.654\n",
      "INFO:root:Epoch[3] Validation-RMSE=0.965186\n",
      "INFO:root:Epoch[4] Batch [312]\tSpeed: 30882.11 samples/sec\tTrain-RMSE=0.941631\n",
      "INFO:root:Epoch[4] Batch [624]\tSpeed: 30760.84 samples/sec\tTrain-RMSE=0.944820\n",
      "INFO:root:Epoch[4] Batch [936]\tSpeed: 30836.56 samples/sec\tTrain-RMSE=0.947041\n",
      "INFO:root:Epoch[4] Batch [1248]\tSpeed: 30889.28 samples/sec\tTrain-RMSE=0.958895\n",
      "INFO:root:Epoch[4] Resetting Data Iterator\n",
      "INFO:root:Epoch[4] Time cost=2.625\n",
      "INFO:root:Epoch[4] Validation-RMSE=1.021695\n",
      "INFO:root:Epoch[5] Batch [312]\tSpeed: 31184.36 samples/sec\tTrain-RMSE=0.939983\n",
      "INFO:root:Epoch[5] Batch [624]\tSpeed: 31011.36 samples/sec\tTrain-RMSE=0.942205\n",
      "INFO:root:Epoch[5] Batch [936]\tSpeed: 30927.93 samples/sec\tTrain-RMSE=0.948742\n",
      "INFO:root:Epoch[5] Batch [1248]\tSpeed: 31253.58 samples/sec\tTrain-RMSE=0.946025\n",
      "INFO:root:Epoch[5] Resetting Data Iterator\n",
      "INFO:root:Epoch[5] Time cost=2.604\n",
      "INFO:root:Epoch[5] Validation-RMSE=0.971389\n",
      "INFO:root:Epoch[6] Batch [312]\tSpeed: 28975.54 samples/sec\tTrain-RMSE=0.944098\n",
      "INFO:root:Epoch[6] Batch [624]\tSpeed: 29132.85 samples/sec\tTrain-RMSE=0.948457\n",
      "INFO:root:Epoch[6] Batch [936]\tSpeed: 29212.85 samples/sec\tTrain-RMSE=0.939584\n",
      "INFO:root:Epoch[6] Batch [1248]\tSpeed: 29189.70 samples/sec\tTrain-RMSE=0.949756\n",
      "INFO:root:Epoch[6] Resetting Data Iterator\n",
      "INFO:root:Epoch[6] Time cost=2.781\n",
      "INFO:root:Epoch[6] Validation-RMSE=0.979447\n",
      "INFO:root:Epoch[7] Batch [312]\tSpeed: 29228.83 samples/sec\tTrain-RMSE=0.933935\n",
      "INFO:root:Epoch[7] Batch [624]\tSpeed: 29092.05 samples/sec\tTrain-RMSE=0.944123\n",
      "INFO:root:Epoch[7] Batch [936]\tSpeed: 29902.46 samples/sec\tTrain-RMSE=0.943008\n",
      "INFO:root:Epoch[7] Batch [1248]\tSpeed: 31193.96 samples/sec\tTrain-RMSE=0.952697\n",
      "INFO:root:Epoch[7] Resetting Data Iterator\n",
      "INFO:root:Epoch[7] Time cost=2.714\n",
      "INFO:root:Epoch[7] Validation-RMSE=0.966694\n",
      "INFO:root:Epoch[8] Batch [312]\tSpeed: 29643.79 samples/sec\tTrain-RMSE=0.941655\n",
      "INFO:root:Epoch[8] Batch [624]\tSpeed: 29277.61 samples/sec\tTrain-RMSE=0.932646\n",
      "INFO:root:Epoch[8] Batch [936]\tSpeed: 29298.30 samples/sec\tTrain-RMSE=0.948542\n",
      "INFO:root:Epoch[8] Batch [1248]\tSpeed: 29119.72 samples/sec\tTrain-RMSE=0.936999\n",
      "INFO:root:Epoch[8] Resetting Data Iterator\n",
      "INFO:root:Epoch[8] Time cost=2.760\n",
      "INFO:root:Epoch[8] Validation-RMSE=0.964616\n",
      "INFO:root:Epoch[9] Batch [312]\tSpeed: 29232.17 samples/sec\tTrain-RMSE=0.934462\n",
      "INFO:root:Epoch[9] Batch [624]\tSpeed: 29045.67 samples/sec\tTrain-RMSE=0.946235\n",
      "INFO:root:Epoch[9] Batch [936]\tSpeed: 29202.37 samples/sec\tTrain-RMSE=0.943892\n",
      "INFO:root:Epoch[9] Batch [1248]\tSpeed: 29250.80 samples/sec\tTrain-RMSE=0.939333\n",
      "INFO:root:Epoch[9] Resetting Data Iterator\n",
      "INFO:root:Epoch[9] Time cost=2.779\n",
      "INFO:root:Epoch[9] Validation-RMSE=1.057628\n"
     ]
    }
   ],
   "source": [
    "# @@@ AUTOTEST_OUTPUT_IGNORED_CELL\n",
    "def get_one_layer_mlp(hidden, k):\n",
    "    # input\n",
    "    user = mx.symbol.Variable('user')\n",
    "    item = mx.symbol.Variable('item')\n",
    "    score = mx.symbol.Variable('score')\n",
    "    # user latent features\n",
    "    user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)\n",
    "    user = mx.symbol.Activation(data = user, act_type=\"relu\")\n",
    "    user = mx.symbol.FullyConnected(data = user, num_hidden = hidden)\n",
    "    # item latent features\n",
    "    item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)\n",
    "    item = mx.symbol.Activation(data = item, act_type=\"relu\")\n",
    "    item = mx.symbol.FullyConnected(data = item, num_hidden = hidden)\n",
    "    # predict by the inner product\n",
    "    pred = user * item\n",
    "    pred = mx.symbol.sum_axis(data = pred, axis = 1)\n",
    "    pred = mx.symbol.Flatten(data = pred)\n",
    "    # loss layer\n",
    "    pred = mx.symbol.LinearRegressionOutput(data = pred, label = score)\n",
    "    return pred\n",
    "\n",
    "train(get_one_layer_mlp(64, 64), batch_size=64, num_epoch=10, learning_rate=.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adding dropout layers to relief the over-fitting. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Start training with [gpu(0)]\n",
      "INFO:root:Epoch[0] Batch [312]\tSpeed: 31588.74 samples/sec\tTrain-RMSE=1.284921\n",
      "INFO:root:Epoch[0] Batch [624]\tSpeed: 32668.93 samples/sec\tTrain-RMSE=1.007235\n",
      "INFO:root:Epoch[0] Batch [936]\tSpeed: 32648.85 samples/sec\tTrain-RMSE=0.988519\n",
      "INFO:root:Epoch[0] Batch [1248]\tSpeed: 32748.70 samples/sec\tTrain-RMSE=0.971257\n",
      "INFO:root:Epoch[0] Resetting Data Iterator\n",
      "INFO:root:Epoch[0] Time cost=2.507\n",
      "INFO:root:Epoch[0] Validation-RMSE=1.019258\n",
      "INFO:root:Epoch[1] Batch [312]\tSpeed: 32408.79 samples/sec\tTrain-RMSE=0.950525\n",
      "INFO:root:Epoch[1] Batch [624]\tSpeed: 32447.41 samples/sec\tTrain-RMSE=0.953668\n",
      "INFO:root:Epoch[1] Batch [936]\tSpeed: 32325.17 samples/sec\tTrain-RMSE=0.948842\n",
      "INFO:root:Epoch[1] Batch [1248]\tSpeed: 32282.63 samples/sec\tTrain-RMSE=0.958244\n",
      "INFO:root:Epoch[1] Resetting Data Iterator\n",
      "INFO:root:Epoch[1] Time cost=2.504\n",
      "INFO:root:Epoch[1] Validation-RMSE=0.988164\n",
      "INFO:root:Epoch[2] Batch [312]\tSpeed: 32210.61 samples/sec\tTrain-RMSE=0.950271\n",
      "INFO:root:Epoch[2] Batch [624]\tSpeed: 32531.03 samples/sec\tTrain-RMSE=0.945151\n",
      "INFO:root:Epoch[2] Batch [936]\tSpeed: 32286.91 samples/sec\tTrain-RMSE=0.945528\n",
      "INFO:root:Epoch[2] Batch [1248]\tSpeed: 32432.80 samples/sec\tTrain-RMSE=0.951414\n",
      "INFO:root:Epoch[2] Resetting Data Iterator\n",
      "INFO:root:Epoch[2] Time cost=2.506\n",
      "INFO:root:Epoch[2] Validation-RMSE=0.971108\n",
      "INFO:root:Epoch[3] Batch [312]\tSpeed: 32375.11 samples/sec\tTrain-RMSE=0.949024\n",
      "INFO:root:Epoch[3] Batch [624]\tSpeed: 32215.49 samples/sec\tTrain-RMSE=0.948121\n",
      "INFO:root:Epoch[3] Batch [936]\tSpeed: 32142.47 samples/sec\tTrain-RMSE=0.931569\n",
      "INFO:root:Epoch[3] Batch [1248]\tSpeed: 32337.05 samples/sec\tTrain-RMSE=0.946255\n",
      "INFO:root:Epoch[3] Resetting Data Iterator\n",
      "INFO:root:Epoch[3] Time cost=2.512\n",
      "INFO:root:Epoch[3] Validation-RMSE=0.979659\n",
      "INFO:root:Epoch[4] Batch [312]\tSpeed: 32381.31 samples/sec\tTrain-RMSE=0.937760\n",
      "INFO:root:Epoch[4] Batch [624]\tSpeed: 32408.33 samples/sec\tTrain-RMSE=0.939204\n",
      "INFO:root:Epoch[4] Batch [936]\tSpeed: 32315.08 samples/sec\tTrain-RMSE=0.948780\n",
      "INFO:root:Epoch[4] Batch [1248]\tSpeed: 32315.90 samples/sec\tTrain-RMSE=0.946596\n",
      "INFO:root:Epoch[4] Resetting Data Iterator\n",
      "INFO:root:Epoch[4] Time cost=2.508\n",
      "INFO:root:Epoch[4] Validation-RMSE=0.989020\n",
      "INFO:root:Epoch[5] Batch [312]\tSpeed: 32452.73 samples/sec\tTrain-RMSE=0.931686\n",
      "INFO:root:Epoch[5] Batch [624]\tSpeed: 33032.10 samples/sec\tTrain-RMSE=0.936187\n",
      "INFO:root:Epoch[5] Batch [936]\tSpeed: 32601.67 samples/sec\tTrain-RMSE=0.938291\n",
      "INFO:root:Epoch[5] Batch [1248]\tSpeed: 31272.81 samples/sec\tTrain-RMSE=0.940872\n",
      "INFO:root:Epoch[5] Resetting Data Iterator\n",
      "INFO:root:Epoch[5] Time cost=2.510\n",
      "INFO:root:Epoch[5] Validation-RMSE=0.959602\n",
      "INFO:root:Epoch[6] Batch [312]\tSpeed: 31634.43 samples/sec\tTrain-RMSE=0.920564\n",
      "INFO:root:Epoch[6] Batch [624]\tSpeed: 31489.65 samples/sec\tTrain-RMSE=0.942716\n",
      "INFO:root:Epoch[6] Batch [936]\tSpeed: 31589.54 samples/sec\tTrain-RMSE=0.943799\n",
      "INFO:root:Epoch[6] Batch [1248]\tSpeed: 31413.52 samples/sec\tTrain-RMSE=0.942387\n",
      "INFO:root:Epoch[6] Resetting Data Iterator\n",
      "INFO:root:Epoch[6] Time cost=2.570\n",
      "INFO:root:Epoch[6] Validation-RMSE=1.002785\n",
      "INFO:root:Epoch[7] Batch [312]\tSpeed: 31559.69 samples/sec\tTrain-RMSE=0.931971\n",
      "INFO:root:Epoch[7] Batch [624]\tSpeed: 31461.23 samples/sec\tTrain-RMSE=0.936018\n",
      "INFO:root:Epoch[7] Batch [936]\tSpeed: 31464.00 samples/sec\tTrain-RMSE=0.935134\n",
      "INFO:root:Epoch[7] Batch [1248]\tSpeed: 31442.45 samples/sec\tTrain-RMSE=0.939548\n",
      "INFO:root:Epoch[7] Resetting Data Iterator\n",
      "INFO:root:Epoch[7] Time cost=2.573\n",
      "INFO:root:Epoch[7] Validation-RMSE=0.997704\n",
      "INFO:root:Epoch[8] Batch [312]\tSpeed: 31483.89 samples/sec\tTrain-RMSE=0.913341\n",
      "INFO:root:Epoch[8] Batch [624]\tSpeed: 31470.20 samples/sec\tTrain-RMSE=0.932365\n",
      "INFO:root:Epoch[8] Batch [936]\tSpeed: 31448.09 samples/sec\tTrain-RMSE=0.937986\n",
      "INFO:root:Epoch[8] Batch [1248]\tSpeed: 31527.63 samples/sec\tTrain-RMSE=0.941808\n",
      "INFO:root:Epoch[8] Resetting Data Iterator\n",
      "INFO:root:Epoch[8] Time cost=2.578\n",
      "INFO:root:Epoch[8] Validation-RMSE=0.965740\n",
      "INFO:root:Epoch[9] Batch [312]\tSpeed: 31604.54 samples/sec\tTrain-RMSE=0.923620\n",
      "INFO:root:Epoch[9] Batch [624]\tSpeed: 31263.40 samples/sec\tTrain-RMSE=0.926735\n",
      "INFO:root:Epoch[9] Batch [936]\tSpeed: 31391.19 samples/sec\tTrain-RMSE=0.919276\n",
      "INFO:root:Epoch[9] Batch [1248]\tSpeed: 31518.78 samples/sec\tTrain-RMSE=0.931024\n",
      "INFO:root:Epoch[9] Resetting Data Iterator\n",
      "INFO:root:Epoch[9] Time cost=2.579\n",
      "INFO:root:Epoch[9] Validation-RMSE=0.950934\n"
     ]
    }
   ],
   "source": [
    "# @@@ AUTOTEST_OUTPUT_IGNORED_CELL\n",
    "def get_one_layer_dropout_mlp(hidden, k):\n",
    "    # input\n",
    "    user = mx.symbol.Variable('user')\n",
    "    item = mx.symbol.Variable('item')\n",
    "    score = mx.symbol.Variable('score')\n",
    "    # user latent features\n",
    "    user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)\n",
    "    user = mx.symbol.Activation(data = user, act_type=\"relu\")\n",
    "    user = mx.symbol.FullyConnected(data = user, num_hidden = hidden)\n",
    "    user = mx.symbol.Dropout(data=user, p=0.5)\n",
    "    # item latent features\n",
    "    item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)\n",
    "    item = mx.symbol.Activation(data = item, act_type=\"relu\")\n",
    "    item = mx.symbol.FullyConnected(data = item, num_hidden = hidden)\n",
    "    item = mx.symbol.Dropout(data=item, p=0.5)    \n",
    "    # predict by the inner product\n",
    "    pred = user * item\n",
    "    pred = mx.symbol.sum_axis(data = pred, axis = 1)\n",
    "    pred = mx.symbol.Flatten(data = pred)\n",
    "    # loss layer\n",
    "    pred = mx.symbol.LinearRegressionOutput(data = pred, label = score)\n",
    "    return pred\n",
    "train(get_one_layer_mlp(256, 512), batch_size=64, num_epoch=10, learning_rate=.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Acknowledgement\n",
    "\n",
    "This tutorial is based on examples from [xlvector/github](https://github.com/xlvector/)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
